Extended model (empirical)

We load the required packages, the helper functions, the empirical data, and the Stan model.
# Packages
  library(tidyverse)
  library(tidybayes)
  library(rethinking)
  library(patchwork)
  library(cmdstanr)
  library(GGally)

# Silence dplyr's "summarise() has grouped output by ..." messages
options(dplyr.summarise.inform = FALSE)

# Helper functions, each serialised as ./functions/<name>.rds
read_helper <- function(name) {
  readRDS(file.path("./functions", paste0(name, ".rds")))
}
post_plot        <- read_helper("post_plot")
post_plot_2      <- read_helper("post_plot_2")
plot_interval    <- read_helper("plot_interval")
hist_plot        <- read_helper("hist_plot")
binom_pmf        <- read_helper("binom_pmf")
hist_plot_simple <- read_helper("hist_plot_simple")

# Empirical data, passed as `data =` to the Stan sampler below.
# NOTE(review): read with readRDS() despite the ".stan" extension - confirm
# the file really is an RDS file (an ".rds" name would be less confusing).
d <- readRDS("./empirical_data/d_stan_format.stan")

# CmdStanModel for the extended model (cmdstan_model() compiles by default)
m_1 <- cmdstan_model("./stan_models/extended_model.stan")

We fit model `m_1` to the data `d`, then save the raw fit, its draws, and its MCMC summary:

# Fit the extended model to the empirical data: 8 chains run in parallel,
# each with 1000 warm-up and 1000 sampling iterations.
post_1 <- m_1$sample(
  data = d,
  iter_warmup = 1e3,
  iter_sampling = 1e3,
  chains = 8,
  parallel_chains = 8
)

# Save the raw fit. `$save_object()` first reads every lazily-loaded output
# (draws, sampler diagnostics, ...) from CmdStan's CSV files into memory and
# then calls saveRDS(). A plain saveRDS(post_1) stores an object that only
# points to those CSVs. By default they live in a temporary directory that is
# deleted when the R session ends, so that saved fit could not be used later.
post_1$save_object(file = "./fitted_models/02.2/post_empirical_1_raw.rds")

# Extract the draws as a tidy data frame (one row per draw) and save them
post_1_draws <- post_1 %>% tidy_draws()
post_1_draws %>%
  saveRDS("./fitted_models/02.2/post_empirical_1_draws.rds")

# MCMC summary: one row per model variable, including the rhat, ess_bulk and
# ess_tail columns used in the diagnostics section
diagnostics_1 <- post_1$summary()
diagnostics_1 %>%
  saveRDS("./fitted_models/02.2/diagnostics_empirical_1.rds")

We then read the saved MCMC draws and diagnostics back in:

# Reload the tidy draws and the MCMC summary saved after fitting
draws_dir <- "./fitted_models/02.2"
post <- readRDS(file.path(draws_dir, "post_empirical_1_draws.rds"))
diagnostics <- readRDS(file.path(draws_dir, "diagnostics_empirical_1.rds"))

1 Diagnostics

# Per-variable convergence diagnostics, taken from the saved `$summary()`
# table (`diagnostics`).
Rhats <- diagnostics$rhat
ess_bulk <- diagnostics$ess_bulk
ess_tail <- diagnostics$ess_tail

# R-hat histogram. NA entries are dropped before plotting; presumably they
# belong to variables that are constant across draws - TODO confirm.
Rhats %>%
  discard(is.na) %>%
  hist_plot_simple(
    min = 0.999,
    max = 1.015,
    breaks = 50,
    fill = "#acbfb7",
    linewidth = 0.5
  )

# Histogram of an ESS vector (NAs dropped), with a dotted reference line at
# ESS = 250. The x-axis runs up to the largest ESS plus 1000.
plot_ess <- function(ess) {
  ess %>%
    discard(is.na) %>%
    hist_plot_simple(
      max = max(ess, na.rm = TRUE) + 1000,
      fill = "#acbfb7",
      breaks = 50,
      linewidth = 0.5
    ) +
    geom_vline(
      xintercept = 250, linewidth = 0.4, linetype = "dotted",
      color = "#ccccc0"
    )
}

plot_ess(ess_bulk)
plot_ess(ess_tail)

Trace plots for a selection of parameters (first 500 iterations of each chain):

Show code:
# Trace plots for a hand-picked subset of parameters: the first 500
# iterations of every chain, one facet per parameter, one colour per chain.
post %>%
  select(`alpha_1[1]`, `delta`, `proba`, `alpha_2[1]`, `alpha_gr[1]`,
         `alpha_1_gr[1]`, `alpha_2_gr[1]`, `alpha_gr_1[1]`, `alpha_gr_2[1]`,
         `Omega_1[1,1]`, `Omega_2[1,1]`, `Omega_gr[1,1]`,
         `c_ind[1,2]`, `c_dyad[1,2]`, `phi_1[1]`, `phi_2[1]`, `sigma_phi[1]`,
         `sigma_tau[1]`, `phi_give_gr[1]`, `phi_rec_gr[1]`, `tau_1[1]`, `tau_2[1]`,
         `tau_gr_ab[1]`, `tau_gr_ba[1]`, `.iteration`, `.chain`) %>%
  # Long format: one row per (draw, parameter)
  pivot_longer(
    -c(.iteration, .chain),
    names_to = "param",
    values_to = "value"
  ) %>%
  filter(.iteration < 501) %>%
  ggplot(aes(x = .iteration, y = value, color = as.factor(.chain))) +
  geom_line(alpha = 0.8, linewidth = 0.4) +
  # `scales` is the full argument name; `scale =` only worked through
  # partial argument matching
  facet_wrap(~ param, scales = "free", ncol = 6) +
  theme_bw() +
  theme(
    axis.text.x = element_text(vjust = 0),
    axis.ticks.y = element_blank(),
    panel.grid = element_blank(),
    legend.position = "none",
    panel.grid.major.y = element_line(
      color = "#ccccc0",
      linewidth = 0.3,
      linetype = "dotted"
    ),
    panel.border = element_rect(fill = "transparent", color = "#ccccc0"),
    panel.spacing = unit(1, "lines"),
    strip.background = element_rect(fill = "white", color = "white"),
    plot.margin = margin(0, 1, 0, 0, "cm")
  ) +
  labs(y = "", x = "") +
  # Exactly 8 colours, one per chain (the fit uses `chains = 8`)
  scale_color_manual(values = c("#304551", "#42535e", "#54616d", "#66727d",
                                "#8d99a3", "#a4b0b8", "#bdc9d1", "#dce9ee")) +
  scale_x_continuous(breaks = c(0, 500))

2 Expected holding times and contrasts

Estimated average holding times in each state per sex combination:

# Constant `iter` column, passed to post_plot_2() alongside every parameter.
# NOTE(review): check post_plot_2's definition for what it uses `iter` for.
post <- post %>%
  mutate(iter = 1)

# Draw one posterior density panel with post_plot_2().
#   data:         posterior draws (must contain the `param` column and `iter`)
#   param:        name of the column to plot, e.g. "avg_s1[1]"
#   x_min, x_max: axis limits, passed to post_plot_2() as `min` / `max`
#   scale:        the draws are divided by this before plotting
#   lower_bound:  passed through to post_plot_2()
#   hide_x:       drop x-axis text and ticks (for panels stacked above others)
# The data handed to post_plot_2() has the same columns as before:
# <param>, iter, value.
plot_post_param <- function(data, param, x_min, x_max, scale = 1,
                            lower_bound = 0, hide_x = TRUE) {
  p <- data %>%
    select(all_of(param), iter) %>%
    mutate(value = .data[[param]] / scale) %>%
    post_plot_2(
      min = x_min,
      max = x_max,
      lower_bound = lower_bound,
      target_values = 0,
      alpha_filling = 0.3,
      alpha_outline = 0.8
    )
  if (hide_x) {
    p <- p +
      theme(axis.text.x = element_blank(),
            axis.ticks.x = element_blank())
  }
  p
}

# avg_s1 draws are divided by 60 * 12 before plotting.
# NOTE(review): presumably a unit conversion - confirm the intended unit.
p1 <- plot_post_param(post, "avg_s1[1]", x_min = 0, x_max = 40, scale = 60 * 12)
p2 <- plot_post_param(post, "avg_s1[2]", x_min = 0, x_max = 40, scale = 60 * 12)
p3 <- plot_post_param(post, "avg_s1[3]", x_min = 0, x_max = 40, scale = 60 * 12,
                      hide_x = FALSE)

p4 <- plot_post_param(post, "avg_s2[1]", x_min = 0, x_max = 4)
p5 <- plot_post_param(post, "avg_s2[2]", x_min = 0, x_max = 4)
p6 <- plot_post_param(post, "avg_s2[3]", x_min = 0, x_max = 4, hide_x = FALSE)

p7 <- plot_post_param(post, "avg_gr[1]", x_min = 0, x_max = 10)
p8 <- plot_post_param(post, "avg_gr[2]", x_min = 0, x_max = 10)
p9 <- plot_post_param(post, "avg_gr[3]", x_min = 0, x_max = 10)
p10 <- plot_post_param(post, "avg_gr[4]", x_min = 0, x_max = 10, hide_x = FALSE)

# Columns: avg_s1 (A-C), avg_s2 (D-F), avg_gr (G-J)
design <-
"ADG
 BEH
 BEI
 CFJ"

p1 + p2 + p3 + p4 + p5 + 
p6 + p7 + p8 + p9 + p10 +
  plot_layout(design = design)

Marginal contrasts:

# Draw one posterior density panel with post_plot_2().
#   data:         posterior draws (must contain the `param` column and `iter`)
#   param:        name of the column to plot, e.g. "avg_s1[1]"
#   x_min, x_max: axis limits, passed to post_plot_2() as `min` / `max`
#   scale:        the draws are divided by this before plotting
#   lower_bound:  passed through to post_plot_2()
#   hide_x:       drop x-axis text and ticks (for panels stacked above others)
# The data handed to post_plot_2() has the same columns as before:
# <param>, iter, value.
plot_post_param <- function(data, param, x_min, x_max, scale = 1,
                            lower_bound = 0, hide_x = TRUE) {
  p <- data %>%
    select(all_of(param), iter) %>%
    mutate(value = .data[[param]] / scale) %>%
    post_plot_2(
      min = x_min,
      max = x_max,
      lower_bound = lower_bound,
      target_values = 0,
      alpha_filling = 0.3,
      alpha_outline = 0.8
    )
  if (hide_x) {
    p <- p +
      theme(axis.text.x = element_blank(),
            axis.ticks.x = element_blank())
  }
  p
}

# Marginal contrasts. ATE_s1 draws are divided by 60 * 12, as for avg_s1.
# NOTE(review): these contrasts can be negative (the axes go below 0), yet
# lower_bound = 0 is passed to post_plot_2(). Confirm it does not truncate the
# density at 0.
p1 <- plot_post_param(post, "ATE_s1[1]", x_min = -12, x_max = 12, scale = 60 * 12)
p2 <- plot_post_param(post, "ATE_s1[2]", x_min = -12, x_max = 12, scale = 60 * 12,
                      hide_x = FALSE)

p3 <- plot_post_param(post, "ATE_s2[1]", x_min = -4, x_max = 4)
p4 <- plot_post_param(post, "ATE_s2[2]", x_min = -4, x_max = 4, hide_x = FALSE)

p5 <- plot_post_param(post, "ATE_gr[1]", x_min = -4, x_max = 4)
p6 <- plot_post_param(post, "ATE_gr[2]", x_min = -4, x_max = 4)
p7 <- plot_post_param(post, "ATE_gr[3]", x_min = -4, x_max = 4, hide_x = FALSE)

# Columns: ATE_s1 (A-B), ATE_s2 (C-D), ATE_gr (E-G)
design <-
"ACE
 ACF
 BDG"

p1 + p2 + p3 + 
p4 + p5 + p6 + p7 +
  plot_layout(design = design)

3 Posterior predictions

3.1 Holding times

# Observed holding-time records (one row per j = 1, ..., J), built from the
# B_* elements of the Stan data list `d`
obs_ht <- with(d, tibble(dyad = B_dyad, s = B_s, x = B_x, c = B_c))

# Posterior draws of theta[dyad, state], plus delta and proba, used to
# simulate holding times, in long format: one row per (draw, dyad, state).
# The retained draws are spread evenly over the whole posterior. `post` rows
# appear to be ordered by chain and then iteration, so the first 30 rows would
# be 30 consecutive (autocorrelated) iterations of chain 1 only.
n_pp_draws <- 30
thetas <- post %>%
  select(starts_with("theta["), delta, proba) %>%
  slice(round(seq(1, n(), length.out = n_pp_draws))) %>%
  mutate(sample = row_number()) %>%
  pivot_longer(
    starts_with("theta["),
    names_to = c("dyad", "s"),
    names_pattern = "theta\\[(\\d+),(\\d+)\\]",
    names_transform = list(dyad = as.integer, s = as.integer),
    values_to = "value"
  )



# One simulated copy of the holding-time data per retained theta draw.
# Each observed (dyad, s) row gets a new holding time drawn from an
# exponential distribution with mean `value` (rate 1 / value). For state 1,
# with probability 1 - proba, the mean is delta instead of theta. Simulated
# values above 40 are dropped (the histogram bins only cover [0, 40]).
post_pred_x <- map(seq_len(max(thetas$sample)), function(spl) {
  obs_ht %>%
    select(-x) %>%
    left_join(filter(thetas, sample == spl), by = c("dyad", "s")) %>%
    mutate(
      long = rbinom(n(), 1, proba),
      value = ifelse(s == 1 & long == 0, delta, value)
    ) %>%
    mutate(x = rexp(n(), 1 / value)) %>%
    filter(x <= 40)
})

# Bin the *simulated* (posterior-predictive) holding times into bins of width
# `bin_width` over [0, x_max] and count them per (state, replicate, bin).
# Each value goes to its nearest bin centre; a value exactly on a bin edge goes
# to the lower bin, and anything past the last centre goes to the last bin.
# This matches the previous per-element which.min(abs(centres - x)), but is
# vectorised.
bin_width <- 5
x_max <- 40
bin_centres <- seq(bin_width / 2, x_max - bin_width / 2, by = bin_width)

pp_x_tbl <- post_pred_x %>%
  bind_rows() %>%
  mutate(bin = bin_centres[pmin(pmax(ceiling(x / bin_width), 1),
                                length(bin_centres))]) %>%
  group_by(s, sample, bin) %>%
  summarise(count = n()) %>%
  group_by(s, sample) %>%
  # Empty bins get an explicit count of 0
  complete(bin = bin_centres, fill = list(count = 0)) %>%
  ungroup()

# Histogram of the observed holding times in one state (rows with c == 0),
# overlaid with that state's binned posterior-predictive counts.
# NOTE(review): the simulated times in pp_x_tbl are not filtered on `c`,
# whereas the observed ones are - confirm this is intended.
plot_state_ht <- function(state, colbin, col_post) {
  obs_ht %>%
    filter(s == state & c == 0) %>%
    hist_plot(
      colbin = colbin,
      posterior = 1,
      post_data = pp_x_tbl %>% filter(s == state),
      alpha_point = 0.6,
      alpha_line = 0.4,
      alpha_hist = 0.4,
      col_post = col_post
    )
}

p1 <- plot_state_ht(1, colbin = "#c9c7bd", col_post = "#c9c7bd")
p2 <- plot_state_ht(2, colbin = "#E9C4C1", col_post = "#d1a7a3")
p3 <- plot_state_ht(3, colbin = "#ce8da6", col_post = "#a37184")
p4 <- plot_state_ht(4, colbin = "#996282", col_post = "#7a4e68")

3.2 Transitions

# Observed transitions: dyad and origin state (s_from), built from the C_*
# elements of the Stan data list `d`
obs_tr <- with(d, tibble(dyad = C_dyad, s_from = C_s_from))

# Posterior draws of Gamma[dyad, s_from, s_to], reshaped to one row per
# (draw, dyad, s_from). The probabilities for each destination state go in
# columns s_to_1, s_to_2, ...
# The retained draws are spread evenly over the whole posterior. `post` rows
# appear to be ordered by chain and then iteration, so the first 30 rows would
# be 30 consecutive (autocorrelated) iterations of chain 1 only.
n_pp_draws <- 30
gammas <- post %>%
  select(starts_with("Gamma[")) %>%
  slice(round(seq(1, n(), length.out = n_pp_draws))) %>%
  mutate(sample = row_number()) %>%
  pivot_longer(
    starts_with("Gamma["),
    names_to = c("dyad", "s_from", "s_to"),
    names_pattern = "Gamma\\[(\\d+),(\\d+),(\\d+)\\]",
    names_transform = list(dyad = as.integer, s_from = as.integer,
                           s_to = as.integer),
    values_to = "value"
  ) %>%
  pivot_wider(
    names_from = s_to,
    values_from = value,
    names_prefix = "s_to_"
  )


# One simulated copy of the observed transitions per retained Gamma draw.
# Each observed transition gets a destination state drawn from the four
# s_to_* probabilities of its (dyad, s_from) row. The simulated transitions
# are then counted per (s_from, s_to).
# NOTE(review): (s_from, s_to) pairs that never occur in a replicate are
# absent rather than counted as 0 (unlike the holding-time bins, which are
# completed with zeros). Check whether binom_pmf() expects explicit zeros.
post_pred_tr <- map(seq_len(max(gammas$sample)), function(spl) {
  obs_tr %>%
    left_join(filter(gammas, sample == spl), by = c("dyad", "s_from")) %>%
    mutate(s_to = mapply(
      function(s1, s2, s3, s4) sample(1:4, size = 1, prob = c(s1, s2, s3, s4)),
      s_to_1, s_to_2, s_to_3, s_to_4
    )) %>%
    group_by(s_from, s_to) %>%
    # Deliberately left grouped by s_from (the summarise default). The later
    # as.factor(s_to) on the stacked replicates runs per s_from group, which
    # fixes the factor-level order the transition-plot colours depend on.
    summarise(count = n(), .groups = "drop_last") %>%
    mutate(sample = spl)
})

# Stack the simulated transition counts; s_to becomes a factor for plotting.
# NOTE(review): the stacked table is still grouped by s_from, so as.factor()
# runs per group. The resulting levels appear to be the union of the
# per-group levels in group order (2, 3, 4, 1), not sorted (1-4).
post_pred_tr_tbl <- bind_rows(post_pred_tr) %>%
  mutate(s_to = as.factor(s_to))

# Observed transition counts per (s_from, s_to); ungrouped, with s_to levels
# built from the whole d$C_s_to vector
obs_tr_smry <- obs_tr %>%
  mutate(s_to = as.factor(d$C_s_to)) %>%
  count(s_from, s_to, name = "count")

# Observed vs simulated transition counts out of one origin state.
plot_tr <- function(from, col_data, col_post) {
  binom_pmf(
    data = obs_tr_smry %>% filter(s_from == from),
    col_data = col_data,
    col_post = col_post,
    posterior = 1,
    post_data = post_pred_tr_tbl %>% filter(s_from == from)
  )
}

# State colours (as in the holding-time histograms): 1 "#c9c7bd",
# 2 "#E9C4C1", 3 "#ce8da6", 4 "#996282".
# col_data follows the sorted s_to levels of obs_tr_smry. col_post appears
# to follow the s_to level order of post_pred_tr_tbl (2, 3, 4, 1, produced by
# the grouped as.factor()), which is why it is permuted from s_from = 2 on.
# NOTE(review): fragile - ungrouping post_pred_tr would silently mismatch
# these colours.
p5 <- plot_tr(1,
              col_data = c("#E9C4C1", "#ce8da6", "#996282"),
              col_post = c("#E9C4C1", "#ce8da6", "#996282"))
p6 <- plot_tr(2,
              col_data = c("#c9c7bd", "#ce8da6", "#996282"),
              col_post = c("#ce8da6","#996282", "#c9c7bd"))
p7 <- plot_tr(3,
              col_data = c("#c9c7bd", "#E9C4C1", "#996282"),
              col_post = c("#E9C4C1", "#996282", "#c9c7bd"))
p8 <- plot_tr(4,
              col_data = c("#c9c7bd", "#E9C4C1", "#ce8da6"),
              col_post = c("#E9C4C1", "#ce8da6", "#c9c7bd"))

# One row per state: holding-time panel (p1-p4) next to its transition panel
# (p5-p8)
wrap_plots(p1, p5, p2, p6,
           p3, p7, p4, p8, ncol = 2)